import csv
import os
import re
import time
import argparse
import pandas as pd
from tqdm import tqdm
import openai
from openai import OpenAI

from google import genai
from google.genai import types as genai_types

# Import templates from preference_templates.py
from preference_templates import (
    system_prompt, 
    instruction_following_template, 
    honesty_template, 
    truthfulness_template,
    helpfulness_template
)
# Lookup table from preference name to the prompt template used to judge it.
# The keys are the values accepted for the ``preference`` argument of get_eval.
TEMPLATES = dict(
    instruction_following=instruction_following_template,
    honesty=honesty_template,
    truthfulness=truthfulness_template,
    helpfulness=helpfulness_template,
)


# Upper bound (seconds) for the exponential back-off between retry attempts.
_MAX_BACKOFF_SECONDS = 30


def _call_with_retries(request, max_retries):
    """Run ``request`` until it succeeds or ``max_retries`` attempts are used.

    Args:
        request: Zero-argument callable performing one API call and returning
            the model's text output.
        max_retries: Maximum number of attempts.

    Returns:
        The value returned by the first successful ``request()`` call, or the
        fixed ``"Error: ..."`` string when every attempt raised.
    """
    for attempt in range(max_retries):
        try:
            return request()
        except Exception as e:  # best-effort: any API failure is reported and retried
            print(f"Error: {e}")
            # Back off 1s, 2s, 4s, ... (capped) between attempts, but do not
            # waste time sleeping once the final attempt has already failed.
            if attempt < max_retries - 1:
                time.sleep(min(2 ** attempt, _MAX_BACKOFF_SECONDS))
    return "Error: Failed to get evaluation after multiple retries"


def get_eval(instruction, response, preference, api_key=None, max_retries=5, provider="openai"):
    """Call API to get evaluation for a prompt-response pair

    Args:
        instruction: The instruction prompt
        response: The generated response to evaluate
        preference: Type of evaluation (instruction_following, honesty, truthfulness, helpfulness)
        api_key: API key for the provider
        max_retries: Maximum number of retry attempts
        provider: Model provider ("openai" or "deepseek" or "gemini")

    Returns:
        The text returned by the model, or the string
        "Error: Failed to get evaluation after multiple retries" when every
        attempt raised an exception.

    Raises:
        ValueError: If ``preference`` or ``provider`` is unknown, or if the
            DeepSeek provider is selected and no DeepSeek API key is available.
    """
    # Get the template for the specific preference
    template = TEMPLATES.get(preference)
    if not template:
        raise ValueError(f"Unknown preference: {preference}")

    # Format the template with the instruction and response
    user_prompt = template.format(instruction=instruction, response=response)

    if provider in ("openai", "deepseek"):
        if provider == "deepseek":
            base_url = "https://api.deepseek.com"
            model = "deepseek-chat"
            # The OpenAI client is documented to fall back to OPENAI_API_KEY
            # when api_key is None; never let that key be sent to the DeepSeek
            # endpoint by accident.
            api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
            if not api_key:
                raise ValueError(
                    "No DeepSeek API key provided. Pass api_key or set DEEPSEEK_API_KEY"
                )
        else:
            base_url = None
            model = "gpt-4.1"

        client = OpenAI(api_key=api_key, base_url=base_url)

        def request():
            # Named ``completion`` so the ``response`` argument is not shadowed.
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0,
                max_tokens=500,
                top_p=0.6,
                presence_penalty=0,
                frequency_penalty=0
            )
            return completion.choices[0].message.content

    elif provider == "gemini":
        gemini_model = "gemini-2.5-flash-preview-04-17"
        client = genai.Client(api_key=api_key)

        def request():
            resp = client.models.generate_content(
                model=gemini_model,
                contents=user_prompt,
                config=genai_types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    thinking_config=genai_types.ThinkingConfig(thinking_budget=0)
                )
            )
            return resp.text

    else:
        raise ValueError(f"Unknown provider: {provider}")

    return _call_with_retries(request, max_retries)

# Matches "Rating: <rating>" followed on a later line by "Rationale: <text>".
# Tolerates missing/extra whitespace around the colon and Markdown emphasis
# asterisks around the labels (e.g. "**Rating:** 4"). DOTALL lets the
# rationale span several lines.
_EVAL_PATTERN = re.compile(
    r"Rating\**\s*:\**\s*(\S.*?)\s*\n\s*\**Rationale\**\s*:\**\s*(.*)",
    re.DOTALL,
)


def process_response(response_text, preference):
    """Process the API response to extract rating and rationale

    Args:
        response_text: Raw text returned by the evaluation model. Non-string
            values (e.g. ``None``) are treated as unparseable.
        preference: Name of the evaluated preference. Unused, because every
            template shares the same output format; kept for the callers.

    Returns:
        ``{"Rating": ..., "Rationale": ...}`` on success, where the rating is
        "N/A", the first integer found in the rating text, or the rating text
        itself when it holds no integer. On failure,
        ``{"Error": "Failed to parse response", "Raw": response_text}``.
    """
    fallback = {"Error": "Failed to parse response", "Raw": response_text}

    # Guard explicitly instead of relying on a TypeError from the regex engine
    if not isinstance(response_text, str):
        return fallback

    matches = _EVAL_PATTERN.search(response_text)
    if matches is None:
        return fallback

    rating_text = matches.group(1).strip()
    if "N/A" in rating_text:
        rating = "N/A"
    else:
        # Extract numeric rating if present, otherwise keep the raw text
        numeric_matches = re.findall(r'\b\d+\b', rating_text)
        rating = numeric_matches[0] if numeric_matches else rating_text

    return {
        "Rating": rating,
        "Rationale": matches.group(2).strip()
    }

def _build_average_row(df, preferences):
    """Build the summary row holding the mean score of every preference.

    Args:
        df: Evaluated dataframe containing a ``<preference>_score`` column for
            each entry of ``preferences``.
        preferences: Names of the evaluated preferences.

    Returns:
        A dict keyed by column name: "Average Scores" in ``prompt``, the mean
        (rounded to 2 decimals) of the numeric scores or "N/A" per score
        column, and "N/A" for rationale and unrelated columns.
    """
    avg_scores_row = {"prompt": "Average Scores"}
    for col in df.columns:
        if col not in ["prompt", "response"] and "_score" not in col and "_rationale" not in col:
            avg_scores_row[col] = "N/A"  # placeholder for non-score/rationale columns

    for preference in preferences:
        score_column = f"{preference}_score"
        # Non-numeric scores ("N/A", "Error", free text) are coerced to NaN and
        # skipped by mean(); the mean is NaN when no numeric score exists.
        avg_score = pd.to_numeric(df[score_column], errors='coerce').mean()
        avg_scores_row[score_column] = round(avg_score, 2) if pd.notnull(avg_score) else "N/A"
        avg_scores_row[f"{preference}_rationale"] = "N/A"

    return avg_scores_row


def evaluate_csv(input_file, preferences, max_rows, output_file=None, api_key=None, provider="openai"):
    """Evaluate prompt-response pairs from a CSV file

    Args:
        input_file: Path of a CSV file with ``prompt`` and ``response`` columns.
        preferences: Preference names to evaluate for every row.
        max_rows: Maximum number of leading rows to evaluate.
        output_file: Path of the CSV to write. Defaults to the input path with
            ``_evaluated_<provider>`` inserted before the extension.
        api_key: API key forwarded to ``get_eval``.
        provider: Model provider forwarded to ``get_eval``.

    Returns:
        The evaluated dataframe, including the trailing "Average Scores" row,
        which is also written to ``output_file``.

    Raises:
        KeyError: If the input CSV lacks the ``prompt`` or ``response`` column.
    """
    if not output_file:
        base, ext = os.path.splitext(input_file)
        output_file = f"{base}_evaluated_{provider}{ext}"

    # Read input CSV
    df = pd.read_csv(input_file)

    # Fail before any API call is made rather than on the first row access
    missing = [col for col in ("prompt", "response") if col not in df.columns]
    if missing:
        raise KeyError(
            f"Input CSV {input_file} is missing required column(s): {', '.join(missing)}"
        )

    total_rows = len(df)
    # Clamp to [0, total_rows]; a negative value would make head() drop rows
    # from the end instead of evaluating none.
    max_rows = max(0, min(max_rows, total_rows))

    # Copy so the column assignments below modify an independent frame and do
    # not act on a slice of the frame that was read.
    df = df.head(max_rows).copy()  # Only evaluate the first max_rows rows

    print(f"We have {max_rows} out of {total_rows} rows to evaluate")

    # Add columns for evaluations
    for preference in preferences:
        df[f"{preference}_score"] = None
        df[f"{preference}_rationale"] = None

    # Process each row
    for idx, row in tqdm(df.iterrows(), total=max_rows, desc="Evaluating"):
        prompt = row['prompt']
        for preference in preferences:
            if isinstance(prompt, str):
                print(f"Evaluating {preference} for prompt: {prompt[:30]}...")
            else:
                print(f"Evaluating {preference} for prompt: {prompt}")

            # Get evaluation - one API call per prompt & response & preference
            eval_response = get_eval(
                prompt,
                row['response'],
                preference,
                api_key=api_key,
                provider=provider
            )

            print(f"Evaluation response: {eval_response}")

            # Process response
            result = process_response(eval_response, preference)

            print(f"Result: {result}")

            # Update dataframe
            if 'Rating' in result:
                df.at[idx, f"{preference}_score"] = result["Rating"]
                df.at[idx, f"{preference}_rationale"] = result["Rationale"]
            else:
                df.at[idx, f"{preference}_score"] = "Error"
                df.at[idx, f"{preference}_rationale"] = result.get("Raw", "Failed to process")

    # Append the average scores as a final row
    avg_df = pd.DataFrame([_build_average_row(df, preferences)])
    df = pd.concat([df, avg_df], ignore_index=True)

    # Save results
    df.to_csv(output_file, index=False)
    print(f"Evaluation complete. Results saved to {output_file}")
    return df

if __name__ == "__main__":
    # Command-line entry point: parse the options, resolve the API key and
    # run the evaluation over the given CSV file.
    cli = argparse.ArgumentParser(description="Evaluate LLM responses with specified preferences")
    cli.add_argument("--input_file", help="CSV file with prompt,response columns")
    cli.add_argument(
        "--preferences",
        nargs="+",
        default=["honesty", "helpfulness", "instruction_following"],
        choices=["instruction_following", "honesty", "truthfulness", "helpfulness"],
        help="Preferences to evaluate",
    )
    cli.add_argument("--output_file", help="Output CSV file (default: input_file with _evaluated suffix)")
    cli.add_argument("--api_key", help="API key")
    cli.add_argument(
        "--provider",
        default="openai",
        choices=["openai", "deepseek", "gemini"],
        help="Model provider (openai = gpt-4.1, deepseek = deepseek-chat, gemini = gemini-2.5-flash)",
    )
    cli.add_argument("--max_rows", type=int, default=200, help="Maximum number of rows to evaluate")

    args = cli.parse_args()

    print("Selected preferences: ", args.preferences)

    # Without --api_key, look the key up in the provider's environment variable
    if not args.api_key:
        provider_env_vars = {
            "openai": "OPENAI_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "gemini": "GOOGLE_API_KEY",
        }
        env_var = provider_env_vars.get(args.provider)
        args.api_key = os.environ.get(env_var) if env_var else None
        if args.api_key:
            print("Got API key from environment variable")
        else:
            print(f"Warning: No {args.provider} API key provided. Please set {env_var} environment variable or use --api_key")

    if not args.input_file:
        cli.error("You must provide an input file using --input_file")

    evaluate_csv(
        args.input_file,
        args.preferences,
        args.max_rows,
        output_file=args.output_file,
        api_key=args.api_key,
        provider=args.provider,
    )